{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import random\n",
    "import mindspore\n",
    "import mindspore.nn as nn\n",
    "import mindspore.dataset as ds\n",
    "from mindnlp.modules import CRF\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def seed_everything(seed):\n",
    "    random.seed(seed)\n",
    "    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    mindspore.set_seed(seed)\n",
    "    mindspore.dataset.config.set_seed(seed)\n",
    "\n",
    "# 读取文本，返回词典，索引表，句子，标签\n",
    "def read_data(path):\n",
    "    sentences = []\n",
    "    labels = []\n",
    "    with open(path, 'r', encoding='utf-8') as f:\n",
    "        sent = []\n",
    "        label = []\n",
    "        for line in f:\n",
    "            parts = line.split()\n",
    "            if len(parts) == 0:\n",
    "                if len(sent) != 0:\n",
    "                    sentences.append(sent)\n",
    "                    labels.append(label)\n",
    "                sent = []\n",
    "                label = []\n",
    "            else:\n",
    "                sent.append(parts[0])\n",
    "                label.append(parts[-1])\n",
    "                \n",
    "    return (sentences, labels)\n",
    "\n",
    "# 返回词典映射表、词数字典\n",
    "def get_dict(sentences):\n",
    "    max_number = 1\n",
    "    char_number_dict={}\n",
    "\n",
    "    id_indexs={}\n",
    "    id_indexs['paddding']=0\n",
    "    id_indexs['unknow']=1\n",
    "    \n",
    "    for sent in sentences:\n",
    "        for c in sent:\n",
    "            if c not in char_number_dict:\n",
    "                char_number_dict[c]=0\n",
    "            char_number_dict[c]+=1\n",
    "                \n",
    "    for c,n in char_number_dict.items():\n",
    "        if n>=max_number:\n",
    "            id_indexs[c]=len(id_indexs)\n",
    "            \n",
    "    return char_number_dict, id_indexs\n",
    "\n",
    "def get_entity(decode):\n",
    "    starting=False\n",
    "    p_ans=[]\n",
    "    for i,label in enumerate(decode):\n",
    "        if label > 0:\n",
    "            if label%2==1:\n",
    "                starting=True\n",
    "                p_ans.append(([i],labels_text_mp[label//2]))\n",
    "            elif starting:\n",
    "                p_ans[-1][0].append(i)\n",
    "        else:\n",
    "            starting=False\n",
    "    return p_ans\n",
    "\n",
    "# 处理数据 \n",
    "class Feature(object):\n",
    "    def __init__(self,sent, label):\n",
    "        self.or_text = sent  #文本原句\n",
    "        self.seq_length = len(sent) if len(sent) < Max_Len else Max_Len\n",
    "        self.labels = [LABEL_MAP[c] for c in label][:Max_Len] + [0]*(Max_Len - len(label)) # 标签\n",
    "        self.token_ids = self.tokenizer(sent)[:Max_Len]  + [0]*(Max_Len - len(sent)) #文本token\n",
    "        self.entity = get_entity(self.labels)\n",
    "        \n",
    "    def tokenizer(self, sent):\n",
    "        token_ids = []\n",
    "        for c in sent:\n",
    "            if c in id_indexs.keys():\n",
    "                token_ids.append(id_indexs[c])\n",
    "            else:\n",
    "                token_ids.append(id_indexs['unknow'])\n",
    "        return token_ids\n",
    "\n",
    "class GetDatasetGenerator:\n",
    "    def __init__(self, data):\n",
    "        self.features = [Feature(data[0][i], data[1][i]) for i in range(len(data[0]))]\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.features)\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        feature = self.features[index]\n",
    "        token_ids = feature.token_ids\n",
    "        labels = feature.labels\n",
    "        \n",
    "        return (token_ids, feature.seq_length, labels)\n",
    "    \n",
    "def debug_dataset(dataset):\n",
    "    dataset = dataset.batch(batch_size=16)\n",
    "    for data in dataset.create_dict_iterator():\n",
    "        print(data[\"data\"].shape, data[\"label\"].shape)\n",
    "        break\n",
    "        \n",
    "def get_metric(P_ans, valid):\n",
    "    predict_score = 0 # 预测正确个数\n",
    "    predict_number = 0 # 预测结果个数\n",
    "    totol_number = 0 # 标签个数\n",
    "    for i in range(len(P_ans)):\n",
    "        predict_number += len(P_ans[i])\n",
    "        totol_number += len(valid.features[i].entity)\n",
    "        pred_true = [x for x in valid.features[i].entity if x in P_ans[i]]\n",
    "        predict_score += len(pred_true)\n",
    "    P = predict_score/predict_number if predict_number>0 else 0.\n",
    "    R = predict_score/totol_number if totol_number>0 else 0.\n",
    "    f1=(2*P*R)/(P+R) if (P+R)>0 else 0.\n",
    "    print(f'f1 = {f1}， P(准确率) = {P}, R(召回率) = {R}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTM_CRF(nn.Module):\n",
    "    def __init__(self,embedding_num,embedding_dim,num_labels):\n",
    "        super().__init__()\n",
    "        self.num_labels = num_labels\n",
    "        self.embedding_num = embedding_num\n",
    "        self.embedding_dim = embedding_dim\n",
    "        self.model_name = 'LSTM_CRF'\n",
    "        self.em = nn.Embedding(vocab_size=self.embedding_num,embedding_size=self.embedding_dim, padding_idx=0)\n",
    "        self.bilstm = nn.LSTM(embedding_dim, embedding_dim//2, batch_first=True, bidirectional=True)\n",
    "        self.crf_hidden_fc = nn.Dense(embedding_dim, self.num_labels)\n",
    "        self.crf = CRF(self.num_labels, batch_first=True, reduction='mean')\n",
    "\n",
    "    def construct(self, ids, seq_length=None, labels=None):\n",
    "        seq=self.em(ids)\n",
    "        lstm_feat, _ = self.bilstm(seq)\n",
    "        emissions = self.crf_hidden_fc(lstm_feat)\n",
    "        loss_crf = self.crf(emissions, tags=labels, seq_length=seq_length)\n",
    "        return loss_crf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 42\n",
    "seed_everything(seed)\n",
    "Max_Len = 113\n",
    "Entity = ['PER', 'LOC', 'ORG', 'MISC']\n",
    "labels_text_mp={k:v for k,v in enumerate(Entity)}\n",
    "LABEL_MAP = {'O': 0}\n",
    "for i, e in enumerate(Entity):\n",
    "    LABEL_MAP[f'B-{e}'] = 2 * (i+1) - 1\n",
    "    LABEL_MAP[f'I-{e}'] = 2 * (i+1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# !wget https://data.deepai.org/conll2003.zip --no-check-certificate -O conll2003.zip\n",
    "# !unzip -o conll2003.zip -d conll2003"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = read_data('conll2003/train.txt')\n",
    "test = read_data('conll2003/test.txt')\n",
    "dev = read_data('conll2003/valid.txt')\n",
    "char_number_dict, id_indexs = get_dict(train[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "Epoch = 2\n",
    "batch_size = 16\n",
    "dataset_generator = GetDatasetGenerator(train)\n",
    "dataset = ds.GeneratorDataset(dataset_generator, [\"data\", \"length\", \"label\"], shuffle=False)\n",
    "dataset_train = dataset.batch(batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = LSTM_CRF(embedding_num=len(id_indexs), embedding_dim=256, num_labels=len(Entity)*2+1)\n",
    "optimizer = nn.Adam(model.trainable_params(), learning_rate=0.001)\n",
    "grad_fn = mindspore.value_and_grad(model, None, optimizer.parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "@mindspore.jit\n",
    "def train_step(token_ids, seq_length, labels):\n",
    "    loss, grads = grad_fn(token_ids, seq_length, labels)\n",
    "    optimizer(grads)\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 937/937 [04:59<00:00,  3.13it/s, loss=5.65]\n",
      "100%|█████████████████████████████████████████| 937/937 [04:54<00:00,  3.18it/s, loss=3.71]\n"
     ]
    }
   ],
   "source": [
    "# 训练\n",
    "size = dataset_train.get_dataset_size()\n",
    "steps = size\n",
    "tloss = []\n",
    "for epoch in range(Epoch):\n",
    "    model.set_train()\n",
    "    with tqdm(total=steps) as t:\n",
    "        for batch, (token_ids, seq_length, labels) in enumerate(dataset_train.create_tuple_iterator()):\n",
    "            loss = train_step(token_ids, seq_length, labels)\n",
    "            tloss.append(loss.asnumpy())\n",
    "            t.set_postfix(loss=np.array(tloss).mean())\n",
    "            t.update(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████| 937/937 [06:18<00:00,  2.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "f1 = 0.8882163554410817， P(准确率) = 0.8960634013251917, R(召回率) = 0.8805055534278055\n"
     ]
    }
   ],
   "source": [
    "# 预测：train\n",
    "dataset_generator = GetDatasetGenerator(train)\n",
    "dataset = ds.GeneratorDataset(dataset_generator, [\"data\", \"length\", \"label\"], shuffle=False)\n",
    "dataset_train = dataset.batch(batch_size=batch_size)\n",
    "\n",
    "size = dataset_train.get_dataset_size()\n",
    "steps = size\n",
    "decodes=[]\n",
    "model.set_train(False)\n",
    "with tqdm(total=steps) as t:\n",
    "    for batch, (token_ids, seq_length, labels) in enumerate(dataset_train.create_tuple_iterator()):\n",
    "        score, history = model(token_ids, seq_length=seq_length)\n",
    "        best_tags = CRF.post_decode(score, history, seq_length)\n",
    "        decode = [[y.asnumpy().item() for y in x] for x in best_tags]\n",
    "        decodes.extend(list(decode))\n",
    "        t.update(1)\n",
    "        \n",
    "v_pred = [get_entity(x) for x in decodes]\n",
    "get_metric(v_pred, dataset_generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████| 217/217 [01:31<00:00,  2.37it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "f1 = 0.7575705437026841， P(准确率) = 0.7749032030975008, R(召回率) = 0.7409962975429149\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# 预测：dev\n",
    "dev_dataset_generator = GetDatasetGenerator(dev)\n",
    "dataset_dev = ds.GeneratorDataset(dev_dataset_generator, [\"data\", \"length\", \"label\"], shuffle=False)\n",
    "dataset_dev = dataset_dev.batch(batch_size=batch_size)\n",
    "\n",
    "size = dataset_dev.get_dataset_size()\n",
    "steps = size\n",
    "decodes=[]\n",
    "model.set_train(False)\n",
    "with tqdm(total=steps) as t:\n",
    "    for batch, (token_ids, seq_length, labels) in enumerate(dataset_dev.create_tuple_iterator()):\n",
    "        score, history = model(token_ids, seq_length=seq_length)\n",
    "        best_tags = model.crf.post_decode(score, history, seq_length)\n",
    "        decode = [[y.asnumpy().item() for y in x] for x in best_tags]\n",
    "        decodes.extend(list(decode))\n",
    "        t.update(1)\n",
    "v_pred = [get_entity(x) for x in decodes]\n",
    "get_metric(v_pred, dev_dataset_generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████| 231/231 [01:33<00:00,  2.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "f1 = 0.6655755591925804， P(准确率) = 0.6838565022421524, R(召回率) = 0.6482465462274176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# 预测：test\n",
    "test_dataset_generator = GetDatasetGenerator(test)\n",
    "dataset_test = ds.GeneratorDataset(test_dataset_generator, [\"data\", \"length\", \"label\"], shuffle=False)\n",
    "dataset_test = dataset_test.batch(batch_size=batch_size)\n",
    "\n",
    "size = dataset_test.get_dataset_size()\n",
    "steps = size\n",
    "decodes_pred=[]\n",
    "model.set_train(False)\n",
    "with tqdm(total=steps) as t:\n",
    "    for batch, (token_ids, seq_length, labels) in enumerate(dataset_test.create_tuple_iterator()):\n",
    "        score, history = model(token_ids, seq_length=seq_length)\n",
    "        best_tags = model.crf.post_decode(score, history, seq_length)\n",
    "        decode = [[y.asnumpy().item() for y in x] for x in best_tags]\n",
    "        decodes_pred.extend(list(decode))\n",
    "        t.update(1)\n",
    "        \n",
    "\n",
    "pred = [get_entity(x) for x in decodes_pred]\n",
    "get_metric(pred, test_dataset_generator)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "vscode": {
   "interpreter": {
    "hash": "a62cb8bb4abcff3256df5ab1881dc7c3e7803473070698df3ff917df10adcce5"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
